# adversarialbox.attacks part=================================================================================
#PyTorch Implementation of Papernot's Black-Box Attack
#arXiv:1602.02697

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnF
import torch.optim as optim
from RobustDNN_PGD import clip_norm_, normalize_grad_, get_pgd_loss_fn_by_name
from Evaluate import cal_performance
#%%
def batch_indices(batch_nb, data_length, batch_size):
    """
    Compute the start and end index of one batch.

    When the requested window would run past the end of the data, it is
    shifted back so that it ends exactly at ``data_length``; some earlier
    inputs are then reused to keep the batch full. If the data holds fewer
    than ``batch_size`` inputs, the window covers the whole data instead.

    :param batch_nb: the batch number (0-based)
    :param data_length: the total length of the data being parsed by batches
    :param batch_size: the number of inputs in each batch
    :return: pair of (start, end) indices, usable as ``data[start:end]``
    """
    # Batch start and end index
    start = int(batch_nb * batch_size)
    end = int((batch_nb + 1) * batch_size)

    # When there are not enough inputs left, we reuse some to complete the
    # batch
    if end > data_length:
        shift = end - data_length
        # Clamp at 0: a negative start would be read as "count from the end"
        # by slicing and silently select the wrong (smaller) window.
        start = max(start - shift, 0)
        end -= shift

    return start, end
#%%
def label_data(oracle, device, X_sub, batch_size):
    """
    Query the oracle for a hard label of every sample in X_sub.

    :param oracle: callable model; a 1-D output is treated as a binary logit
                   (label = output > 0), otherwise the label is the argmax
                   over dim 1
    :param device: torch device the batches are moved to before the query
    :param X_sub: numpy array of inputs (converted with torch.from_numpy)
    :param batch_size: number of inputs per oracle query
    :return: numpy int64 array with one label per row of X_sub
    """
    nb_batches = int(np.ceil(float(X_sub.shape[0]) / batch_size))
    Y_sub = np.zeros(X_sub.shape[0], dtype=np.int64)
    # Only the predicted labels are needed, so no autograd graph is recorded
    with torch.no_grad():
        for batch in range(nb_batches):
            start, end = batch_indices(batch, len(X_sub), batch_size)
            X = torch.from_numpy(X_sub[start:end]).to(device)
            Z = oracle(X)
            if len(Z.size()) <= 1:
                Y = (Z > 0).to(torch.int64) #binary/sigmoid
            else:
                Y = Z.max(dim=1)[1] #multiclass/softmax
            # The last window may overlap the previous one; overlapping
            # entries are simply written again.
            Y_sub[start:end] = Y.cpu().numpy()
    return Y_sub
#%%
def jacobian_augmentation(sub_model, device, X_sub, Y_sub, nb_classes, batch_size, lmbda=0.1):
    """
    Build one Jacobian-augmented sample for every sample of X_sub.

    Every input is moved by ``lmbda`` along the sign of the gradient of the
    substitute's output for the sample's own label, then clamped to [0, 1].
    The returned array has the same shape as X_sub (float32); it does not
    contain the original samples, so the caller has to concatenate the two.
    Rows whose label is outside ``range(nb_classes)`` are left as zeros.

    :param sub_model: substitute model; it is switched to eval mode
    :param device: torch device used for the forward/backward passes
    :param X_sub: numpy array of inputs
    :param Y_sub: numpy array of integer labels, one per row of X_sub
    :param nb_classes: number of classes looped over
    :param batch_size: number of inputs processed per batch
    :param lmbda: step size of the augmentation (may be negative)
    :return: numpy float32 array of augmented inputs
    """
    sub_model.eval()
    augmented = np.zeros(X_sub.shape, dtype=np.float32)
    total = len(X_sub)
    for batch_no in range(int(np.ceil(float(total)/batch_size))):
        lo, hi = batch_indices(batch_no, total, batch_size)
        inputs = torch.from_numpy(X_sub[lo:hi]).to(device)
        labels_np = Y_sub[lo:hi]
        labels = torch.from_numpy(labels_np).to(device)
        # Basic slice => view: writes below land directly in `augmented`
        out_view = augmented[lo:hi]
        for cls in range(nb_classes):
            x_cls = inputs[labels == cls].detach()
            if x_cls.size(0) == 0:
                continue
            x_cls.requires_grad = True
            # Raw outputs are used; a softmax on top is not necessary
            score = sub_model(x_cls)[:, cls].sum()
            # autograd.grad instead of backward(): parameter .grad stays untouched
            (grad,) = torch.autograd.grad(score, x_cls)
            moved = x_cls.data + lmbda*grad.sign().data
            moved.clamp_(0, 1)
            out_view[labels_np == cls] = moved.data.cpu().numpy()
    #--------------------------------------------------------
    return augmented
#%%
def train_sub_model(sub_model, oracle, device, dataloader, param, print_msg=False):
    """
    Train a substitute model against a black-box oracle, alternating between
    training on oracle-labelled data and Jacobian-based dataset augmentation.

    The substitute is trained in place; nothing is returned.

    :param sub_model: substitute model to train (moved to ``device``)
    :param oracle: model queried for labels (moved to ``device``, eval mode)
    :param device: torch device used for training and labelling
    :param dataloader: iterable of batches; only element 0 (the inputs) of
                       each batch is used to build the initial data set
    :param param: dict with keys 'dataset_initial_size', 'dataset_max_size',
                  'learning_rate', 'weight_decay', 'batch_size', 'num_epochs',
                  'num_data_aug', 'num_classes'
    :param print_msg: print progress information when truthy
    :raises StopIteration: if the dataloader is exhausted before
                           'dataset_initial_size' inputs were collected
    """
    #----------------------------------------------------
    # Collect inputs until the initial substitute set is large enough
    data_iter = iter(dataloader)
    X_sub=[]
    initial_size=0
    while True:
        # Built-in next(): Python 3 iterators do not provide a .next() method
        X_sub.append(next(data_iter)[0])
        initial_size+=X_sub[-1].size(0)
        if print_msg:
            print(initial_size)
        if  initial_size >= param['dataset_initial_size']:
            break
    X_sub=torch.cat(X_sub, dim=0)
    X_sub=X_sub.data.cpu().numpy().astype(np.float32)
    #----------------------------------------------------
    sub_model.to(device)
    sub_model.train()
    oracle.to(device)
    oracle.eval()
    #----------------------------------------------------
    # Setup training; weight decay is applied manually via weight_decay()
    optimizer = optim.Adamax(sub_model.parameters(), lr=param['learning_rate'], weight_decay=0)
    # Fixed seed => reproducible shuffling order
    rng = np.random.RandomState(0)
    # Train the substitute model and augment the dataset alternately
    for rho in range(param['num_data_aug']+1):
        #Label the substitute training set
        Y_sub=label_data(oracle, device, X_sub, param['batch_size'])
        if print_msg:
            print(X_sub.shape)
        # train sub_model on X_sub, Y_sub
        sub_model.train()
        # Defined before the epoch loop so that the half-augmentation branch
        # below still has an index order when num_epochs is 0
        index_shuf = list(range(X_sub.shape[0]))
        for epoch in range(param['num_epochs']):
            nb_batches = int(np.ceil(float(X_sub.shape[0]) / param['batch_size']))
            assert nb_batches * param['batch_size'] >= X_sub.shape[0]
            # Indices to shuffle training set
            index_shuf = list(range(X_sub.shape[0]))
            rng.shuffle(index_shuf)
            loss_epoch=0
            acc_epoch=0
            nb_seen=0
            for batch in range(nb_batches):
                start, end = batch_indices(batch, X_sub.shape[0], param['batch_size'])
                X = X_sub[index_shuf[start:end]]; X = torch.from_numpy(X).to(device)
                Y = Y_sub[index_shuf[start:end]]; Y = torch.from_numpy(Y).to(device)
                Zs = sub_model(X)
                Yp = Zs.data.max(dim=1)[1] #multiclass/softmax
                loss = nnF.cross_entropy(Zs, Y)
                optimizer.zero_grad()
                loss.backward()
                weight_decay(optimizer,param['weight_decay'])
                optimizer.step()
                loss_epoch+=loss.item()
                acc_epoch+= torch.sum(Yp==Y).item()
                nb_seen+=Y.size(0)
            loss_epoch/=nb_batches
            # The last batch may reuse samples, so normalise by the number of
            # samples actually evaluated rather than by the data set size
            acc_epoch/=nb_seen
            if print_msg:
                print('Zs.abs().max()', Zs.abs().max().item())
                print('Train_Substitute: aug', rho, 'epoch', epoch, 'loss', loss_epoch, 'acc', acc_epoch)
        #Perform Jacobian-based dataset augmentation (skipped after the last round)
        if rho == param['num_data_aug']:
            break
        sub_model.eval()
        if X_sub.shape[0] < param['dataset_max_size']:
            # Step is -0.1 for rho in {0, 1, 2} and +0.1 afterwards
            lmbda_new = 2 * int(int(rho / 3) != 0) - 1
            lmbda_new *= 0.1
            X_aug = jacobian_augmentation(sub_model, device, X_sub, Y_sub, param['num_classes'], param['batch_size'], lmbda=lmbda_new)
            # Data set doubles: originals followed by their augmented copies
            X_sub = np.concatenate([X_sub, X_aug], axis=0)
        else:
            # Size cap reached: augment one half (in the last shuffle order,
            # with the default step) and keep the other half unchanged, so
            # the data set size stays constant
            X_sub_a = X_sub[index_shuf[0:int(X_sub.shape[0]/2)]]
            Y_sub_a = Y_sub[index_shuf[0:int(X_sub.shape[0]/2)]]
            X_aug = jacobian_augmentation(sub_model, device, X_sub_a, Y_sub_a, param['num_classes'], param['batch_size'])
            X_sub_b = X_sub[index_shuf[int(X_sub.shape[0]/2):]]
            #combine old and aug data
            X_sub = np.concatenate([X_sub_b, X_aug], axis=0)
#%%
def weight_decay(optimizer, rate):
    """
    Apply a manual weight-decay step to the parameters held by an optimizer.

    Every parameter that requires grad is updated in place as
    ``p <- p - lr * rate * p``, where ``lr`` is the learning rate of the
    parameter's group. Gradients are neither read nor modified.

    :param optimizer: torch optimizer whose param_groups are walked
    :param rate: weight-decay coefficient
    """
    with torch.no_grad():
        for group in optimizer.param_groups:
            shrink = group['lr']*rate
            trainable = (w for w in group['params'] if w.requires_grad)
            for w in trainable:
                w.sub_(shrink*w)
#%%
def ifgsm_attack(sub_model, oracle, X, Y, noise_norm, norm_type=np.inf, max_iter=None, step=None,
                 targeted=False, clip_X_min=0, clip_X_max=1,
                 stop_if_label_change=False, use_optimizer=False, loss_fn=None):
    """
    Iterative FGSM attack (https://arxiv.org/pdf/1607.02533v4.pdf).

    Thin wrapper around pgd_attack with the random start disabled; all other
    arguments are forwarded unchanged.

    :param max_iter: number of iterations
    :param step: step size per iteration
    :return: whatever pgd_attack returns (the perturbed inputs)

    When both max_iter and step are None they are obtained from
    estimate_max_iter_and_step(noise_norm, norm_type, X=X).
    NOTE(review): estimate_max_iter_and_step is not among the names imported
    at the top of this file -- confirm where it is defined.
    NOTE(review): if only one of max_iter/step is None, that None is passed
    on to pgd_attack as is.
    """
    needs_schedule = (max_iter is None) and (step is None)
    if needs_schedule:
        max_iter, step = estimate_max_iter_and_step(noise_norm, norm_type, X=X)
    # rand_init=False is what distinguishes IFGSM from PGD here
    return pgd_attack(sub_model, oracle, X, Y, noise_norm, norm_type, max_iter, step,
                      rand_init=False, rand_init_max=None, targeted=targeted,
                      clip_X_min=clip_X_min, clip_X_max=clip_X_max,
                      stop_if_label_change=stop_if_label_change,
                      use_optimizer=use_optimizer, loss_fn=loss_fn)
#%% this is Projected Gradient Descent (PGD) attack
def pgd_attack(sub_model, oracle, X, Y, noise_norm, norm_type, max_iter, step,
               rand_init=True, rand_init_max=None, targeted=False, clip_X_min=0, clip_X_max=1,
               stop_if_label_change=False, use_optimizer=False, loss_fn=None):
    """
    Transfer PGD attack: gradients come from the substitute model, while the
    oracle is only queried for its predicted labels.

    :param sub_model: substitute model that is differentiated (set to eval mode)
    :param oracle: black-box model, forward passes only (set to eval mode)
    :param X: clean inputs, expected in the range 0~1
    :param Y: true labels, or the target labels when ``targeted`` is True
    :param noise_norm: maximum norm of the noise (Xn-X), measured in norm_type
    :param norm_type: np.inf, 1 or 2
    :param max_iter: number of attack iterations
    :param step: step size (learning rate of Adamax when use_optimizer is True)
    :param rand_init: start from uniformly random noise instead of from X
    :param rand_init_max: amplitude of the random start; noise_norm when None
    :param targeted: negate the loss so that it is minimised w.r.t. Y
    :param clip_X_min: lower clamp of the perturbed inputs
    :param clip_X_max: upper clamp of the perturbed inputs
    :param stop_if_label_change: stop updating a sample once the oracle's
                                 prediction shows the attack has succeeded
    :param use_optimizer: update the noise with Adamax instead of a plain step
    :param loss_fn: passed through get_pgd_loss_fn_by_name to obtain the loss
    :return: detached perturbed inputs Xn
    """
    #-----------------------------------------------------
    loss_fn=get_pgd_loss_fn_by_name(loss_fn)
    #-----------------------------------------------------
    sub_model.eval()#set  to evaluation mode
    oracle.eval()#set to evaluation mode
    X = X.detach()
    #-----------------
    if stop_if_label_change == True:
        # The oracle is a black box: only its labels are used, so its forward
        # pass never needs an autograd graph
        with torch.no_grad():
            Z=oracle(X)
        if len(Z.size()) <= 1:
            Yp = (Z.data>0).to(torch.int64)
        else:
            Yp = Z.data.max(dim=1)[1]
        # Samples the oracle classifies correctly before the attack
        Yp_e_Y = (Yp==Y)
    #-----------------
    if rand_init == True:
        init_value=rand_init_max
        if rand_init_max is None:
            init_value=noise_norm
        # Uniform noise in [-init_value, init_value]
        noise_init=init_value*(2*torch.rand_like(X)-1)
        # presumably projects the noise onto the norm ball in place -- confirm in RobustDNN_PGD
        clip_norm_(noise_init, norm_type, noise_norm)
        Xn = X + noise_init
    else:
        Xn = X.clone().detach() # must clone
    #-----------------
    noise_new=(Xn-X).detach()
    if use_optimizer == True:
        # The optimizer keeps a reference to this tensor object, which is why
        # noise_new is only ever modified in place on that path
        optimizer = optim.Adamax([noise_new], lr=step)
    #-----------------
    for n in range(0, max_iter):
        Xn = Xn.detach()
        Xn.requires_grad = True
        Zn = sub_model(Xn)
        loss = loss_fn(Zn, Y)
        if targeted == True:#targeted attack, Y should be filled with targeted class label
            loss=-loss
        # autograd.grad instead of loss.backward(): parameter .grad stays untouched
        Xn_grad=torch.autograd.grad(loss, Xn)[0]
        # presumably rescales the gradient in place according to norm_type -- confirm in RobustDNN_PGD
        normalize_grad_(Xn_grad, norm_type)
        if use_optimizer == False:
            Xnew = Xn.detach() + step*Xn_grad.detach()
            noise_new = Xnew-X
        else:
            noise_new.grad=-Xn_grad.detach() #grad ascend
            optimizer.step()
        clip_norm_(noise_new, norm_type, noise_norm)
        #-------------------------------
        # Oracle prediction on the current point (before this iteration's update)
        with torch.no_grad():
            Zo = oracle(Xn.detach())
        if len(Zo.size()) <= 1:
            Ypo = (Zo.data>0).to(torch.int64)
        else:
            Ypo = Zo.data.max(dim=1)[1]
        #---------------------------
        if stop_if_label_change == True:
            # Only samples on which the attack has not succeeded yet are updated
            if targeted == True:
                candidate=(Ypo!=Y)
            else:
                candidate=(Yp_e_Y)&(Ypo==Y)
            Xn=Xn.detach()
            Xn[candidate]=X[candidate] + noise_new[candidate]
        else:
            Xn = X+noise_new
        #---------------------------
        Xn = torch.clamp(Xn, clip_X_min, clip_X_max)
        # In-place equivalent of noise_new = Xn - X: keeps the noise consistent
        # with the clamped (or frozen) Xn without rebinding the tensor
        noise_new.data -= noise_new.data-(Xn-X).data
        #---------------------------
        Xn=Xn.detach()
        #print('n=', n)
    return Xn
#%%
def test_adv(sub_model, oracle, device, dataloader, num_classes, noise_norm, norm_type, max_iter, step,  method,
             targeted=False, clip_X_min=0, clip_X_max=1,
             stop_if_label_change=True, use_optimizer=False, adv_loss_fn=None,
             save_model_output=False, class_balanced_acc=False):
    """
    Evaluate the oracle on clean inputs and on adversarial inputs crafted
    with the substitute model.

    :param sub_model: substitute model used to craft the attack (set to eval mode)
    :param oracle: model under evaluation (set to eval mode)
    :param device: torch device the batches are moved to
    :param dataloader: iterable of (X, Y) batches
    :param num_classes: size of the confusion matrices
    :param noise_norm: maximum noise norm passed to the attack
    :param norm_type: norm type passed to the attack
    :param max_iter: number of attack iterations
    :param step: attack step size
    :param method: 'ifgsm' or 'pgd'
    :param targeted: forwarded to the attack; when False, samples the oracle
                     already misclassifies keep their clean prediction in the
                     noisy statistics
    :param clip_X_min: lower clamp forwarded to the attack
    :param clip_X_max: upper clamp forwarded to the attack
    :param stop_if_label_change: forwarded to the attack
    :param use_optimizer: forwarded to the attack
    :param adv_loss_fn: passed through get_pgd_loss_fn_by_name, then forwarded
    :param save_model_output: also store labels, outputs and predictions
                              under 'y', 'z', 'yp', 'adv_z', 'adv_yp'
    :param class_balanced_acc: forwarded to cal_performance
    :return: dict with the attack settings, sample counts, sample indices,
             confusion matrices and the values returned by cal_performance
    :raises NotImplementedError: for any other ``method``
    """
    sub_model.eval()#set model to evaluation mode
    oracle.eval()#set model to evaluation mode
    #----------------------------------------------------
    # Rows are indexed by the true label, columns by the predicted label
    confusion_clean=np.zeros((num_classes,num_classes))
    confusion_noisy=np.zeros((num_classes,num_classes))
    sample_count=0
    adv_sample_count=0
    sample_idx_wrong=[]
    sample_idx_attack=[]
    if save_model_output == True:
        y_list=[]
        z_list=[]
        yp_list=[]
        adv_z_list=[]
        adv_yp_list=[]
    #---------------------
    print('testing robustness bba1', method)
    print('norm_type:', norm_type, ', noise_norm:', noise_norm, ', max_iter:', max_iter, ', step:', step, sep='')
    # NOTE(review): the resolved value is forwarded as loss_fn and resolved
    # once more inside pgd_attack -- confirm the helper accepts its own output
    adv_loss_fn=get_pgd_loss_fn_by_name(adv_loss_fn)
    print('adv_loss_fn', adv_loss_fn)
    #---------------------
    for batch_idx, (X, Y) in enumerate(dataloader):
        X, Y = X.to(device), Y.to(device)
        #------------------
        # Evaluation only: no autograd graph is needed for the oracle outputs,
        # and Z would otherwise keep its graph alive during the whole attack
        with torch.no_grad():
            Z = oracle(X)
        if method == 'ifgsm':
            Xn = ifgsm_attack(sub_model, oracle, X, Y, noise_norm, norm_type, max_iter, step,
                              targeted, clip_X_min, clip_X_max,
                              stop_if_label_change=stop_if_label_change,
                              use_optimizer=use_optimizer, loss_fn=adv_loss_fn)
        elif method == 'pgd':
            Xn = pgd_attack(sub_model, oracle, X, Y, noise_norm, norm_type, max_iter, step,
                            True, None, targeted, clip_X_min, clip_X_max,
                            stop_if_label_change=stop_if_label_change,
                            use_optimizer=use_optimizer, loss_fn=adv_loss_fn)
        else:
            raise NotImplementedError("other method is not implemented.")
        with torch.no_grad():
            Zn = oracle(Xn)
        #------------------
        if len(Z.size()) <= 1:
            Yp = (Z.data>0).to(torch.int64)
            Ypn = (Zn.data>0).to(torch.int64)
        else:
            Yp = Z.data.max(dim=1)[1] #multiclass/softmax
            Ypn = Zn.data.max(dim=1)[1] #multiclass/softmax
        #------------------
        #do not attack x that is misclassified: such samples keep their clean
        #prediction/output in the noisy statistics
        Ypn_ = Ypn.clone().detach()
        Zn_=Zn.clone().detach()
        if targeted == False:
            temp=(Yp!=Y)
            Ypn_[temp]=Yp[temp]
            Zn_[temp]=Z[temp]
        for i in range(0, confusion_noisy.shape[0]):
            for j in range(0, confusion_noisy.shape[1]):
                confusion_noisy[i,j]+=torch.sum((Y==i)&(Ypn_==j)).item()
        #------------------
        for i in range(0, confusion_clean.shape[0]):
            for j in range(0, confusion_clean.shape[1]):
                confusion_clean[i,j]+=torch.sum((Y==i)&(Yp==j)).item()
        #------------------
        # Global sample indices: misclassified on clean input vs. correctly
        # classified but changed by the attack
        for m in range(0,X.size(0)):
            idx=sample_count+m
            if Y[m] != Yp[m]:
                sample_idx_wrong.append(idx)
            elif Ypn[m] != Yp[m]:
                sample_idx_attack.append(idx)
        sample_count+=X.size(0)
        adv_sample_count+=torch.sum((Yp==Y)&(Ypn!=Y)).item()
        #------------------
        if save_model_output == True:
            y_list.append(Y.detach().to('cpu').numpy())
            z_list.append(Z.detach().to('cpu').numpy())
            yp_list.append(Yp.detach().to('cpu').numpy())
            adv_z_list.append(Zn_.detach().to('cpu').numpy())
            adv_yp_list.append(Ypn_.detach().to('cpu').numpy())
    #------------------
    #------------------
    acc_clean, sens_clean, prec_clean = cal_performance(confusion_clean, class_balanced_acc)
    acc_noisy, sens_noisy, prec_noisy = cal_performance(confusion_noisy, class_balanced_acc)
    result={}
    result['method']=method
    result['noise_norm']=noise_norm
    result['max_iter']=max_iter
    result['step']=step
    result['sample_count']=sample_count
    result['adv_sample_count']=adv_sample_count
    result['sample_idx_wrong']=sample_idx_wrong
    result['sample_idx_attack']=sample_idx_attack
    result['confusion_clean']=confusion_clean
    result['acc_clean']=acc_clean
    result['sens_clean']=sens_clean
    result['prec_clean']=prec_clean
    result['confusion_noisy']=confusion_noisy
    result['acc_noisy']=acc_noisy
    result['sens_noisy']=sens_noisy
    result['prec_noisy']=prec_noisy
    #------------------
    if save_model_output == True:
        y_list = np.concatenate(y_list, axis=0).squeeze().astype('int64')
        z_list=np.concatenate(z_list, axis=0).squeeze()
        yp_list = np.concatenate(yp_list, axis=0).squeeze().astype('int64')
        adv_z_list=np.concatenate(adv_z_list, axis=0).squeeze()
        adv_yp_list = np.concatenate(adv_yp_list, axis=0).squeeze().astype('int64')
        result['y']=y_list
        result['z']=z_list
        result['yp']=yp_list
        result['adv_z']=adv_z_list
        result['adv_yp']=adv_yp_list
    # Guard the ratio so that an empty dataloader does not raise ZeroDivisionError
    adv_ratio = adv_sample_count/sample_count if sample_count > 0 else float('nan')
    print('testing robustness bba1 ', method, ', adv%=', adv_ratio, sep='')
    print('norm_type:', norm_type, ', noise_norm:', noise_norm, ', max_iter:', max_iter, ', step:', step, sep='')
    print('acc_clean', result['acc_clean'], ', acc_noisy', result['acc_noisy'])
    print('sens_clean', result['sens_clean'])
    print('sens_noisy', result['sens_noisy'])
    print('prec_clean', result['prec_clean'])
    print('prec_noisy', result['prec_noisy'])
    return result